"""EM algorithm for weighted low-rank approximation (Srebro & Jaakkola, 2003).

https://aaai-22.aaai.org/Library/ICML/2003/icml03-094.php
"""

import numpy as np

from . import svd


def weighted_loss(matrix_1, matrix_2, weight):
    """Return the weighted sum of squared differences between two arrays.

    The result is ``norm(sqrt(weight) * (matrix_1 - matrix_2)) ** 2``, which
    equals the sum over all entries of ``weight * (matrix_1 - matrix_2) ** 2``.
    """
    scaled_residual = (matrix_1 - matrix_2) * np.sqrt(weight)
    residual_norm = np.linalg.norm(scaled_residual)
    return residual_norm**2


def weighted_lra(
    matrix, weight, rank, epochs=25, initial_solution=None, return_history=False
):
    """Weighted low rank approximation via expectation-maximization.

    Each epoch performs one EM step of Srebro & Jaakkola (2003):
    - It fills in a target by mixing ``matrix`` and the current estimate,
      entry by entry, according to the normalized weights.
    - It projects that target onto rank ``rank`` with ``svd.svd``.

    Parameters
    ----------
    matrix : array_like
        Matrix to approximate.
    weight : array_like
        Non-negative entrywise weights, combined elementwise with ``matrix``.
        Internally they are divided by their maximum so the EM mixing
        coefficients lie in [0, 1].
    rank : int
        Target rank, forwarded to ``svd.svd``.
    epochs : int, optional
        Number of EM iterations; must be at least 1.
    initial_solution : array_like, optional
        Starting estimate; defaults to an all-zeros array of ``matrix``'s
        shape. It is not modified.
    return_history : bool, optional
        If true, also return the weighted loss measured with the original
        (unnormalized) weights after every epoch.

    Returns
    -------
    left_factor, right_factor
        The two values returned by the final ``svd.svd`` call; the
        approximation is ``left_factor @ right_factor``.
    history : list
        Per-epoch losses; returned only when ``return_history`` is true.

    Raises
    ------
    ValueError
        If ``epochs`` is less than 1, if ``weight`` has a negative entry, or
        if ``weight`` has no positive entry.
    """
    if epochs < 1:
        # Without at least one iteration there are no factors to return.
        raise ValueError(f"epochs must be at least 1, got {epochs}")
    weight = np.asarray(weight)
    if np.any(weight < 0):
        raise ValueError("weight must be non-negative")
    max_weight = np.max(weight)
    if max_weight <= 0:
        # Normalizing by a zero or negative maximum would yield NaN or
        # sign-flipped mixing coefficients.
        raise ValueError("weight must have at least one positive entry")
    # EM needs mixing coefficients in [0, 1]; scaling by the max weight
    # achieves this without changing the minimizer of the weighted loss.
    normalized_weight = weight / max_weight

    if initial_solution is None:
        iterate = np.zeros(np.shape(matrix))
    else:
        iterate = initial_solution

    history = []
    for _ in range(epochs):
        # E-step: trust observed entries in proportion to their weight and
        # fall back to the current estimate for the remainder.
        filled = normalized_weight * matrix + (1 - normalized_weight) * iterate
        # M-step: best rank-`rank` approximation of the filled-in matrix.
        left_factor, right_factor = svd.svd(filled, rank=rank)
        iterate = left_factor @ right_factor
        if return_history:
            history.append(weighted_loss(matrix, iterate, weight))

    if return_history:
        return left_factor, right_factor, history
    return left_factor, right_factor
